-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][Vector] Add vector.shuffle
tree transformation
#145740
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
This PR adds a new transformation that turns sequences of `vector.to_elements` and `vector.from_elements` into a binary tree of `vector.shuffle` operations. (Related RFC: https://discourse.llvm.org/t/rfc-adding-vector-to-elements-op-to-the-vector-dialect/86779). Example: ``` %0:4 = vector.to_elements %a : vector<4xf32> %1:4 = vector.to_elements %b : vector<4xf32> %2:4 = vector.to_elements %c : vector<4xf32> %3 = vector.from_elements %0#0, %0#1, %0#2, %0#3, %1#0, %1#1, %1#2, %1#3, %2#0, %2#1, %2#2, %2#3 : vector<12xf32> ==> %0 = vector.shuffle %a, %b [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> %1 = vector.shuffle %c, %c [0, 1, 2, 3, -1, -1, -1, -1] : vector<4xf32>, vector<4xf32> %2 = vector.shuffle %0, %1 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> ``` The algorithm leverages the structured extraction/insertion information of `vector.to_elements` and `vector.from_elements` operations and builds a set of intervals to determine the vector length that should be used at each level of the tree. There are a few improvements that can be implemented in the future, such as shuffle mask compression to avoid unnecessarily large vector lengths with poison values, but I decided to keep things "simpler" and spend more time documenting the different steps of the algorithm so that people can follow along.
@llvm/pr-subscribers-mlir-vector Author: Diego Caballero (dcaballe) ChangesThis PR adds a new transformation that turns sequences of Example:
The algorithm leverages the structured extraction/insertion information of There are a few improvements that can be implemented in the future, such as shuffle mask compression to avoid unnecessarily large vector lengths with poison values, but I decided to keep things "simpler" and spend more time documenting the different steps of the algorithm so that people can follow along. Patch is 62.71 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/145740.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 14cff4ff893b5..6761cd65c5009 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -297,6 +297,13 @@ void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns,
/// n > 1.
void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns);
+/// Populate patterns to rewrite sequences of `vector.to_elements` +
+/// `vector.from_elements` operations into a tree of `vector.shuffle`
+/// operations.
+void populateVectorToFromElementsToShuffleTreePatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit = 1);
+
} // namespace vector
} // namespace mlir
+
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
index 5667f4fa95ace..959c2fbf31f1a 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
@@ -9,6 +9,7 @@
#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_
#define MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_
+#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Pass/Pass.h"
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
index 7436998749791..9431a4d8e240f 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
@@ -34,4 +34,9 @@ def LowerVectorMultiReduction : Pass<"lower-vector-multi-reduction", "func::Func
];
}
+def LowerVectorToFromElementsToShuffleTree
+ : Pass<"lower-vector-to-from-elements-to-shuffle-tree", "func::FuncOp"> {
+ let summary = "Lower `vector.to_elements` and `vector.from_elements` to a tree of `vector.shuffle` operations";
+}
+
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 8ca5cb6c6dfab..9e287fc109990 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
LowerVectorScan.cpp
LowerVectorShapeCast.cpp
LowerVectorStep.cpp
+ LowerVectorToFromElementsToShuffleTree.cpp
LowerVectorTransfer.cpp
LowerVectorTranspose.cpp
SubsetOpInterfaceImpl.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp
new file mode 100644
index 0000000000000..53728d6dbe2a3
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp
@@ -0,0 +1,692 @@
+//===- VectorShuffleTreeBuilder.cpp ----- Vector shuffle tree builder -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements pattern rewrites to lower sequences of
+// `vector.to_elements` and `vector.from_elements` operations into a tree of
+// `vector.shuffle` operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Transforms/Passes.h"
+#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace mlir {
+namespace vector {
+
+#define GEN_PASS_DEF_LOWERVECTORTOFROMELEMENTSTOSHUFFLETREE
+#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
+
+} // namespace vector
+} // namespace mlir
+
+#define DEBUG_TYPE "lower-vector-to-from-elements-to-shuffle-tree"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+
+// Indentation unit for debug output formatting.
+constexpr unsigned kIndScale = 2;
+
+/// Represents a closed interval of elements (e.g., [0, 7] = 8 elements).
+using Interval = std::pair<unsigned, unsigned>;
+// Sentinel value for uninitialized intervals.
+constexpr unsigned kMaxUnsigned = std::numeric_limits<unsigned>::max();
+
+/// The VectorShuffleTreeBuilder builds a balanced binary tree of
+/// `vector.shuffle` operations from one or more `vector.to_elements`
+/// operations feeding a single `vector.from_elements` operation.
+///
+/// The implementation generates hardware-agnostic `vector.shuffle` operations
+/// that minimize both the number of shuffle operations and the length of
+/// intermediate vectors (to the extent possible). The tree has the
+/// following properties:
+///
+/// 1. Vectors are shuffled in pairs by order of appearance in
+/// the `vector.from_elements` operand list.
+/// 2. Each input vector to each level is used only once.
+/// 3. The number of levels in the tree is:
+/// ceil(log2(# `vector.to_elements` ops)).
+/// 4. Vectors at each level of the tree have the same vector length.
+/// 5. Vector positions that do not need to be shuffled are represented with
+/// poison in the shuffle mask.
+///
+/// Examples #1: Concatenation of 3x vector<4xf32> to vector<12xf32>:
+///
+/// %0:4 = vector.to_elements %a : vector<4xf32>
+/// %1:4 = vector.to_elements %b : vector<4xf32>
+/// %2:4 = vector.to_elements %c : vector<4xf32>
+/// %3 = vector.from_elements %0#0, %0#1, %0#2, %0#3, %1#0, %1#1,
+/// %1#2, %1#3, %2#0, %2#1, %2#2, %2#3
+/// : vector<12xf32>
+/// =>
+///
+/// %shuffle0 = vector.shuffle %a, %b [0, 1, 2, 3, 4, 5, 6, 7]
+/// : vector<4xf32>, vector<4xf32>
+/// %shuffle1 = vector.shuffle %c, %c [0, 1, 2, 3, -1, -1, -1, -1]
+/// : vector<4xf32>, vector<4xf32>
+/// %result = vector.shuffle %shuffle0, %shuffle1 [0, 1, 2, 3, 4, 5,
+/// 6, 7, 8, 9, 10, 11]
+/// : vector<8xf32>, vector<8xf32>
+///
+/// Comments:
+/// * The shuffle tree has two levels:
+/// - Level 1 = (%shuffle0, %shuffle1)
+/// - Level 2 = (%result)
+/// * `%a` and `%b` are shuffled first because they appear first in the
+/// `vector.from_elements` operand list (`%0#0` and `%1#0`).
+/// * `%c` is shuffled with itself because the number of
+/// `vector.from_elements` operands is odd.
+/// * The vector length for the first and second levels are 8 and 16,
+/// respectively.
+/// * `%shuffle1` uses poison values to match the vector length of its
+/// tree level (8).
+///
+///
+/// Example #2: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+/// %0:5 = vector.to_elements %a : vector<5xf32>
+/// %1:5 = vector.to_elements %b : vector<5xf32>
+/// %2:5 = vector.to_elements %c : vector<5xf32>
+/// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
+/// %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+/// =>
+///
+/// %shuffle0 = vector.shuffle %[[C]], %[[B]] [2, 6, -1, -1, 7, 2, 0, 6]
+/// : vector<5xf32>, vector<5xf32>
+/// %shuffle1 = vector.shuffle %[[A]], %[[A]] [1, 1, -1, -1, -1, -1, 4, -1]
+/// : vector<5xf32>, vector<5xf32>
+/// %result = vector.shuffle %shuffle0, %shuffle1 [0, 1, 8, 9, 4, 5, 6, 7, 14]
+/// : vector<8xf32>, vector<8xf32>
+///
+/// Comments:
+/// * `%c` and `%b` are shuffled first because they appear first in the
+/// `vector.from_elements` operand list (`%2#2` and `%1#1`).
+/// * `%a` is shuffled with itself because the number of
+/// `vector.from_elements` operands is odd.
+/// * The vector length for the first and second levels are 8 and 9,
+/// respectively.
+/// * `%shuffle0` uses poison values to mark unused vector positions and
+/// match the vector length of its tree level (8).
+///
+/// TODO: Implement mask compression to reduce the number of intermediate poison
+/// values.
+///
+class VectorShuffleTreeBuilder {
+public:
+ VectorShuffleTreeBuilder() = delete;
+ VectorShuffleTreeBuilder(FromElementsOp fromElemOp,
+ ArrayRef<ToElementsOp> toElemDefs);
+
+ /// Analyze the input `vector.to_elements` + `vector.from_elements` sequence
+ /// and compute the shuffle tree configuration. This method does not generate
+ /// any IR.
+ LogicalResult computeShuffleTree();
+
+ /// Materialize the shuffle tree configuration computed by
+ /// `computeShuffleTree` in the IR.
+ Value generateShuffleTree(PatternRewriter &rewriter);
+
+private:
+ // IR input information.
+ FromElementsOp fromElementsOp;
+ SmallVector<ToElementsOp> toElementsDefs;
+
+ // Shuffle tree configuration.
+ unsigned numLevels;
+ SmallVector<unsigned> vectorSizePerLevel;
+ /// Holds the range of positions in the final output that each vector input
+ /// in the tree is contributing to.
+ SmallVector<SmallVector<Interval>> inputIntervalsPerLevel;
+
+ // Utility methods to compute the shuffle tree configuration.
+ void computeInputVectorIntervals();
+ void computeOutputVectorSizePerLevel();
+
+ /// Dump the shuffle tree configuration.
+ void dump();
+};
+
+VectorShuffleTreeBuilder::VectorShuffleTreeBuilder(
+ FromElementsOp fromElemOp, ArrayRef<ToElementsOp> toElemDefs)
+ : fromElementsOp(fromElemOp), toElementsDefs(toElemDefs) {
+
+ assert(fromElementsOp && "from_elements op is required");
+ assert(!toElementsDefs.empty() && "At least one to_elements op is required");
+
+ // Duplicate the last vector if the number of `vector.to_elements` is odd to
+ // simplify the shuffle tree algorithm.
+ if (toElementsDefs.size() % 2 != 0) {
+ toElementsDefs.push_back(toElementsDefs.back());
+ }
+}
+
+// ===--------------------------------------------------------------------===//
+// Shuffle Tree Analysis Utilities.
+// ===--------------------------------------------------------------------===//
+
+/// Compute the intervals for all the input vectors in the shuffle tree. The
+/// interval of an input vector is the range of positions in the final output
+/// that the input vector contributes to.
+///
+/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+/// %0:5 = vector.to_elements %a : vector<5xf32>
+/// %1:5 = vector.to_elements %b : vector<5xf32>
+/// %2:5 = vector.to_elements %c : vector<5xf32>
+/// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
+/// %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+///
+/// Level 0 has 4 inputs (%2, %1, %0, %0, the last one is duplicated to make the
+/// number of inputs even) so we compute the interval for each input vector:
+///
+/// * inputIntervalsPerLevel[0][0] = interval(%2) = [0,6]
+/// * inputIntervalsPerLevel[0][1] = interval(%1) = [1,7]
+/// * inputIntervalsPerLevel[0][2] = interval(%0) = [2,8]
+/// * inputIntervalsPerLevel[0][3] = interval(%0) = [2,8]
+///
+/// Level 1 has 2 inputs, resulting from the shuffling of %2 + %1 and %0 + %0 so
+/// we compute the intervals for each input vector to level 1 as:
+/// * inputIntervalsPerLevel[1][0] = interval(%2) U interval(%1) = [0,7]
+/// * inputIntervalsPerLevel[1][1] = interval(%0) U interval(%0) = [2,8]
+///
+void VectorShuffleTreeBuilder::computeInputVectorIntervals() {
+ // Map `vector.to_elements` ops to their ordinal position in the
+ // `vector.from_elements` operand list. Make sure duplicated
+ // `vector.to_elements` ops are mapped to the its first occurrence.
+ DenseMap<ToElementsOp, unsigned> toElementsToInputOrdinal;
+ for (const auto &[idx, toElementsOp] : llvm::enumerate(toElementsDefs))
+ toElementsToInputOrdinal.insert({toElementsOp, idx});
+
+ // Compute intervals for each input vector in the shuffle tree. The first
+ // level computation is special-cased to keep the implementation simpler.
+
+ SmallVector<Interval> firstLevelIntervals(toElementsDefs.size(),
+ {kMaxUnsigned, kMaxUnsigned});
+
+ for (const auto &[idx, element] :
+ llvm::enumerate(fromElementsOp.getElements())) {
+ auto toElementsOp = cast<ToElementsOp>(element.getDefiningOp());
+ unsigned inputIdx = toElementsToInputOrdinal[toElementsOp];
+ Interval ¤tInterval = firstLevelIntervals[inputIdx];
+
+ // Set lower bound to the first occurrence of the `vector.to_elements`.
+ if (currentInterval.first == kMaxUnsigned)
+ currentInterval.first = idx;
+
+ // Set upper bound to the last occurrence of the `vector.to_elements`.
+ currentInterval.second = idx;
+ }
+
+ // If the number of `vector.to_elements` is odd and the last op was
+ // duplicated, the interval for the duplicated op was not computed in the
+ // previous step as all the input occurrences were mapped to the original op.
+ // We copy the interval of the original op to the interval of the duplicated
+ // op manually.
+ if (firstLevelIntervals.back().second == kMaxUnsigned)
+ firstLevelIntervals.back() = *std::prev(firstLevelIntervals.end(), 2);
+
+ inputIntervalsPerLevel.push_back(std::move(firstLevelIntervals));
+
+ // Compute intervals for the remaining levels.
+ unsigned outputNumElements =
+ cast<VectorType>(fromElementsOp.getResult().getType()).getNumElements();
+ for (unsigned level = 1; level < numLevels; ++level) {
+ const auto &prevLevelIntervals = inputIntervalsPerLevel[level - 1];
+ SmallVector<Interval> currentLevelIntervals(
+ llvm::divideCeil(prevLevelIntervals.size(), 2),
+ {kMaxUnsigned, kMaxUnsigned});
+
+ for (size_t inputIdx = 0; inputIdx < currentLevelIntervals.size();
+ ++inputIdx) {
+ auto &interval = currentLevelIntervals[inputIdx];
+ const auto &prevLhsInterval = prevLevelIntervals[inputIdx * 2];
+ const auto &prevRhsInterval = prevLevelIntervals[inputIdx * 2 + 1];
+
+ // The interval of a vector at the current level is the union of the
+ // intervals of the two input vectors from the previous level being
+ // shuffled at this level.
+ interval.first = std::min(prevLhsInterval.first, prevRhsInterval.first);
+ interval.second =
+ std::min(std::max(prevLhsInterval.second, prevRhsInterval.second),
+ outputNumElements - 1);
+ }
+
+ inputIntervalsPerLevel.push_back(std::move(currentLevelIntervals));
+ }
+}
+
+/// Compute the uniform output vector size for each level of the shuffle tree,
+/// given the intervals of the input vectors at that level. The output vector
+/// size of a level is the size of the widest interval resulting from shuffling
+/// each pair of input vectors.
+///
+/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+/// Intervals:
+/// * Level 0: [0,6], [1,7], [2,8], [2,8]
+/// * Level 1: [0,7], [2,8]
+///
+/// Vector sizes:
+/// * Level 0: max(size_of([0,6] U [1,7] = [0,7]) = 8,
+/// size_of([2,8] U [2,8] = [2,8]) = 7) = 8
+///
+/// * Level 1: max(size_of([0,7] U [2,8] = [0,8]) = 9) = 9
+///
+void VectorShuffleTreeBuilder::computeOutputVectorSizePerLevel() {
+ // Compute vector size for each level.
+ for (unsigned level = 0; level < numLevels; ++level) {
+ const auto ¤tLevelIntervals = inputIntervalsPerLevel[level];
+ unsigned currentVectorSize = 1;
+ for (size_t i = 0; i < currentLevelIntervals.size(); i += 2) {
+ const auto &lhsInterval = currentLevelIntervals[i];
+ const auto &rhsInterval = currentLevelIntervals[i + 1];
+ unsigned combinedIntervalSize =
+ std::max(lhsInterval.second, rhsInterval.second) - lhsInterval.first +
+ 1;
+ currentVectorSize = std::max(currentVectorSize, combinedIntervalSize);
+ }
+ vectorSizePerLevel[level] = currentVectorSize;
+ }
+}
+
+void VectorShuffleTreeBuilder::dump() {
+ LLVM_DEBUG({
+ unsigned indLv = 0;
+
+ llvm::dbgs() << "VectorShuffleTreeBuilder Configuration:\n";
+ ++indLv;
+ llvm::dbgs() << llvm::indent(indLv, kIndScale) << "* Inputs:\n";
+ ++indLv;
+ for (const auto &toElementsOp : toElementsDefs)
+ llvm::dbgs() << llvm::indent(indLv, kIndScale) << toElementsOp << "\n";
+ llvm::dbgs() << llvm::indent(indLv, kIndScale) << fromElementsOp << "\n\n";
+ --indLv;
+
+ llvm::dbgs() << llvm::indent(indLv, kIndScale)
+ << "* Total levels: " << numLevels << "\n";
+ llvm::dbgs() << llvm::indent(indLv, kIndScale)
+ << "* Vector sizes per level: [";
+ llvm::interleaveComma(vectorSizePerLevel, llvm::dbgs());
+ llvm::dbgs() << "]\n";
+ llvm::dbgs() << llvm::indent(indLv, kIndScale)
+ << "* Input intervals per level:\n";
+ ++indLv;
+ for (const auto &[level, intervals] :
+ llvm::enumerate(inputIntervalsPerLevel)) {
+ llvm::dbgs() << llvm::indent(indLv, kIndScale) << "* Level " << level
+ << ": ";
+ llvm::interleaveComma(intervals, llvm::dbgs(),
+ [](const Interval &interval) {
+ llvm::dbgs() << "[" << interval.first << ","
+ << interval.second << "]";
+ });
+ llvm::dbgs() << "\n";
+ }
+ });
+}
+
+/// Compute the shuffle tree configuration for the given `vector.to_elements` +
+/// `vector.from_elements` input sequence. This method builds a balanced binary
+/// shuffle tree that combines pairs of input vectors at each level.
+///
+/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+/// %0:5 = vector.to_elements %a : vector<5xf32>
+/// %1:5 = vector.to_elements %b : vector<5xf32>
+/// %2:5 = vector.to_elements %c : vector<5xf32>
+/// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
+/// %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+///
+/// build a tree that looks like:
+///
+/// %2 %1 %0 %0
+/// \ / \ /
+/// %2_1 = vector.shuffle %0_0 = vector.shuffle
+/// \ /
+/// %2_1_0_0 =vector.shuffle
+///
+/// The configuration comprises of computing the intervals of the input vectors
+/// at each level of the shuffle tree (i.e., %2, %1, %0, %0, %2_1, %0_0 and
+/// %2_1_0_0) and the output vector size for each level. For further details on
+/// intervals and output vector size computation, please, take a look at the
+/// corresponding utility functions.
+LogicalResult VectorShuffleTreeBuilder::computeShuffleTree() {
+ // Initialize shuffle tree information based on its size.
+ assert(toElementsDefs.size() > 1 &&
+ "At least two 'vector.to_elements' ops are required");
+ numLevels = llvm::Log2_64(toElementsDefs.size());
+ vectorSizePerLevel.resize(numLevels, 0);
+ inputIntervalsPerLevel.reserve(numLevels);
+
+ computeInputVectorIntervals();
+ computeOutputVectorSizePerLevel();
+ dump();
+
+ return success();
+}
+
+// ===--------------------------------------------------------------------===//
+// Shuffle Tree Code Generation Utilities.
+// ===--------------------------------------------------------------------===//
+
+/// Compute the permutation mask for shuffling two input `vector.to_elements`
+/// ops. The permutation mask is the mapping of the input vector elements to
+/// their final position in the output vector, relative to the intermediate
+/// output vector of the `vector.shuffle` operation combining the two inputs.
+///
+/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+/// %0:5 = vector.to_elements %a : vector<5xf32>
+/// %1:5 = vector.to_elements %b : vector<5xf32>
+/// %2:5 = vector.to_elements %c : vector<5xf32>
+/// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
+/// %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+///
+/// =>
+///
+/// // Level 0, vector length = 8
+/// %2_1 = PermutationShuffleMask(%2, %1) = [2,...
[truncated]
|
@llvm/pr-subscribers-mlir Author: Diego Caballero (dcaballe) ChangesThis PR adds a new transformation that turns sequences of Example:
The algorithm leverages the structured extraction/insertion information of There are a few improvements that can be implemented in the future, such as shuffle mask compression to avoid unnecessarily large vector lengths with poison values, but I decided to keep things "simpler" and spend more time documenting the different steps of the algorithm so that people can follow along. Patch is 62.71 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/145740.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 14cff4ff893b5..6761cd65c5009 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -297,6 +297,13 @@ void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns,
/// n > 1.
void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns);
+/// Populate patterns to rewrite sequences of `vector.to_elements` +
+/// `vector.from_elements` operations into a tree of `vector.shuffle`
+/// operations.
+void populateVectorToFromElementsToShuffleTreePatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit = 1);
+
} // namespace vector
} // namespace mlir
+
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
index 5667f4fa95ace..959c2fbf31f1a 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
@@ -9,6 +9,7 @@
#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_
#define MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_
+#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Pass/Pass.h"
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
index 7436998749791..9431a4d8e240f 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
@@ -34,4 +34,9 @@ def LowerVectorMultiReduction : Pass<"lower-vector-multi-reduction", "func::Func
];
}
+def LowerVectorToFromElementsToShuffleTree
+ : Pass<"lower-vector-to-from-elements-to-shuffle-tree", "func::FuncOp"> {
+ let summary = "Lower `vector.to_elements` and `vector.from_elements` to a tree of `vector.shuffle` operations";
+}
+
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 8ca5cb6c6dfab..9e287fc109990 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
LowerVectorScan.cpp
LowerVectorShapeCast.cpp
LowerVectorStep.cpp
+ LowerVectorToFromElementsToShuffleTree.cpp
LowerVectorTransfer.cpp
LowerVectorTranspose.cpp
SubsetOpInterfaceImpl.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp
new file mode 100644
index 0000000000000..53728d6dbe2a3
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp
@@ -0,0 +1,692 @@
+//===- VectorShuffleTreeBuilder.cpp ----- Vector shuffle tree builder -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements pattern rewrites to lower sequences of
+// `vector.to_elements` and `vector.from_elements` operations into a tree of
+// `vector.shuffle` operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Transforms/Passes.h"
+#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace mlir {
+namespace vector {
+
+#define GEN_PASS_DEF_LOWERVECTORTOFROMELEMENTSTOSHUFFLETREE
+#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
+
+} // namespace vector
+} // namespace mlir
+
+#define DEBUG_TYPE "lower-vector-to-from-elements-to-shuffle-tree"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+
+// Indentation unit for debug output formatting.
+constexpr unsigned kIndScale = 2;
+
+/// Represents a closed interval of elements (e.g., [0, 7] = 8 elements).
+using Interval = std::pair<unsigned, unsigned>;
+// Sentinel value for uninitialized intervals.
+constexpr unsigned kMaxUnsigned = std::numeric_limits<unsigned>::max();
+
+/// The VectorShuffleTreeBuilder builds a balanced binary tree of
+/// `vector.shuffle` operations from one or more `vector.to_elements`
+/// operations feeding a single `vector.from_elements` operation.
+///
+/// The implementation generates hardware-agnostic `vector.shuffle` operations
+/// that minimize both the number of shuffle operations and the length of
+/// intermediate vectors (to the extent possible). The tree has the
+/// following properties:
+///
+/// 1. Vectors are shuffled in pairs by order of appearance in
+/// the `vector.from_elements` operand list.
+/// 2. Each input vector to each level is used only once.
+/// 3. The number of levels in the tree is:
+/// ceil(log2(# `vector.to_elements` ops)).
+/// 4. Vectors at each level of the tree have the same vector length.
+/// 5. Vector positions that do not need to be shuffled are represented with
+/// poison in the shuffle mask.
+///
+/// Examples #1: Concatenation of 3x vector<4xf32> to vector<12xf32>:
+///
+/// %0:4 = vector.to_elements %a : vector<4xf32>
+/// %1:4 = vector.to_elements %b : vector<4xf32>
+/// %2:4 = vector.to_elements %c : vector<4xf32>
+/// %3 = vector.from_elements %0#0, %0#1, %0#2, %0#3, %1#0, %1#1,
+/// %1#2, %1#3, %2#0, %2#1, %2#2, %2#3
+/// : vector<12xf32>
+/// =>
+///
+/// %shuffle0 = vector.shuffle %a, %b [0, 1, 2, 3, 4, 5, 6, 7]
+/// : vector<4xf32>, vector<4xf32>
+/// %shuffle1 = vector.shuffle %c, %c [0, 1, 2, 3, -1, -1, -1, -1]
+/// : vector<4xf32>, vector<4xf32>
+/// %result = vector.shuffle %shuffle0, %shuffle1 [0, 1, 2, 3, 4, 5,
+/// 6, 7, 8, 9, 10, 11]
+/// : vector<8xf32>, vector<8xf32>
+///
+/// Comments:
+/// * The shuffle tree has two levels:
+/// - Level 1 = (%shuffle0, %shuffle1)
+/// - Level 2 = (%result)
+/// * `%a` and `%b` are shuffled first because they appear first in the
+/// `vector.from_elements` operand list (`%0#0` and `%1#0`).
+/// * `%c` is shuffled with itself because the number of
+/// `vector.from_elements` operands is odd.
+/// * The vector length for the first and second levels are 8 and 16,
+/// respectively.
+/// * `%shuffle1` uses poison values to match the vector length of its
+/// tree level (8).
+///
+///
+/// Example #2: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+/// %0:5 = vector.to_elements %a : vector<5xf32>
+/// %1:5 = vector.to_elements %b : vector<5xf32>
+/// %2:5 = vector.to_elements %c : vector<5xf32>
+/// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
+/// %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+/// =>
+///
+/// %shuffle0 = vector.shuffle %[[C]], %[[B]] [2, 6, -1, -1, 7, 2, 0, 6]
+/// : vector<5xf32>, vector<5xf32>
+/// %shuffle1 = vector.shuffle %[[A]], %[[A]] [1, 1, -1, -1, -1, -1, 4, -1]
+/// : vector<5xf32>, vector<5xf32>
+/// %result = vector.shuffle %shuffle0, %shuffle1 [0, 1, 8, 9, 4, 5, 6, 7, 14]
+/// : vector<8xf32>, vector<8xf32>
+///
+/// Comments:
+/// * `%c` and `%b` are shuffled first because they appear first in the
+/// `vector.from_elements` operand list (`%2#2` and `%1#1`).
+/// * `%a` is shuffled with itself because the number of
+/// `vector.from_elements` operands is odd.
+/// * The vector length for the first and second levels are 8 and 9,
+/// respectively.
+/// * `%shuffle0` uses poison values to mark unused vector positions and
+/// match the vector length of its tree level (8).
+///
+/// TODO: Implement mask compression to reduce the number of intermediate poison
+/// values.
+///
+class VectorShuffleTreeBuilder {
+public:
+ VectorShuffleTreeBuilder() = delete;
+ VectorShuffleTreeBuilder(FromElementsOp fromElemOp,
+ ArrayRef<ToElementsOp> toElemDefs);
+
+ /// Analyze the input `vector.to_elements` + `vector.from_elements` sequence
+ /// and compute the shuffle tree configuration. This method does not generate
+ /// any IR.
+ LogicalResult computeShuffleTree();
+
+ /// Materialize the shuffle tree configuration computed by
+ /// `computeShuffleTree` in the IR.
+ Value generateShuffleTree(PatternRewriter &rewriter);
+
+private:
+ // IR input information.
+ FromElementsOp fromElementsOp;
+ SmallVector<ToElementsOp> toElementsDefs;
+
+ // Shuffle tree configuration.
+ unsigned numLevels;
+ SmallVector<unsigned> vectorSizePerLevel;
+ /// Holds the range of positions in the final output that each vector input
+ /// in the tree is contributing to.
+ SmallVector<SmallVector<Interval>> inputIntervalsPerLevel;
+
+ // Utility methods to compute the shuffle tree configuration.
+ void computeInputVectorIntervals();
+ void computeOutputVectorSizePerLevel();
+
+ /// Dump the shuffle tree configuration.
+ void dump();
+};
+
+VectorShuffleTreeBuilder::VectorShuffleTreeBuilder(
+ FromElementsOp fromElemOp, ArrayRef<ToElementsOp> toElemDefs)
+ : fromElementsOp(fromElemOp), toElementsDefs(toElemDefs) {
+
+ assert(fromElementsOp && "from_elements op is required");
+ assert(!toElementsDefs.empty() && "At least one to_elements op is required");
+
+ // Duplicate the last vector if the number of `vector.to_elements` is odd to
+ // simplify the shuffle tree algorithm.
+ if (toElementsDefs.size() % 2 != 0) {
+ toElementsDefs.push_back(toElementsDefs.back());
+ }
+}
+
+// ===--------------------------------------------------------------------===//
+// Shuffle Tree Analysis Utilities.
+// ===--------------------------------------------------------------------===//
+
+/// Compute the intervals for all the input vectors in the shuffle tree. The
+/// interval of an input vector is the range of positions in the final output
+/// that the input vector contributes to.
+///
+/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+/// %0:5 = vector.to_elements %a : vector<5xf32>
+/// %1:5 = vector.to_elements %b : vector<5xf32>
+/// %2:5 = vector.to_elements %c : vector<5xf32>
+/// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
+/// %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+///
+/// Level 0 has 4 inputs (%2, %1, %0, %0, the last one is duplicated to make the
+/// number of inputs even) so we compute the interval for each input vector:
+///
+/// * inputIntervalsPerLevel[0][0] = interval(%2) = [0,6]
+/// * inputIntervalsPerLevel[0][1] = interval(%1) = [1,7]
+/// * inputIntervalsPerLevel[0][2] = interval(%0) = [2,8]
+/// * inputIntervalsPerLevel[0][3] = interval(%0) = [2,8]
+///
+/// Level 1 has 2 inputs, resulting from the shuffling of %2 + %1 and %0 + %0 so
+/// we compute the intervals for each input vector to level 1 as:
+/// * inputIntervalsPerLevel[1][0] = interval(%2) U interval(%1) = [0,7]
+/// * inputIntervalsPerLevel[1][1] = interval(%0) U interval(%0) = [2,8]
+///
+void VectorShuffleTreeBuilder::computeInputVectorIntervals() {
+ // Map `vector.to_elements` ops to their ordinal position in the
+ // `vector.from_elements` operand list. Make sure duplicated
+ // `vector.to_elements` ops are mapped to the its first occurrence.
+ DenseMap<ToElementsOp, unsigned> toElementsToInputOrdinal;
+ for (const auto &[idx, toElementsOp] : llvm::enumerate(toElementsDefs))
+ toElementsToInputOrdinal.insert({toElementsOp, idx});
+
+ // Compute intervals for each input vector in the shuffle tree. The first
+ // level computation is special-cased to keep the implementation simpler.
+
+ SmallVector<Interval> firstLevelIntervals(toElementsDefs.size(),
+ {kMaxUnsigned, kMaxUnsigned});
+
+ for (const auto &[idx, element] :
+ llvm::enumerate(fromElementsOp.getElements())) {
+ auto toElementsOp = cast<ToElementsOp>(element.getDefiningOp());
+ unsigned inputIdx = toElementsToInputOrdinal[toElementsOp];
+ Interval ¤tInterval = firstLevelIntervals[inputIdx];
+
+ // Set lower bound to the first occurrence of the `vector.to_elements`.
+ if (currentInterval.first == kMaxUnsigned)
+ currentInterval.first = idx;
+
+ // Set upper bound to the last occurrence of the `vector.to_elements`.
+ currentInterval.second = idx;
+ }
+
+ // If the number of `vector.to_elements` is odd and the last op was
+ // duplicated, the interval for the duplicated op was not computed in the
+ // previous step as all the input occurrences were mapped to the original op.
+ // We copy the interval of the original op to the interval of the duplicated
+ // op manually.
+ if (firstLevelIntervals.back().second == kMaxUnsigned)
+ firstLevelIntervals.back() = *std::prev(firstLevelIntervals.end(), 2);
+
+ inputIntervalsPerLevel.push_back(std::move(firstLevelIntervals));
+
+ // Compute intervals for the remaining levels.
+ unsigned outputNumElements =
+ cast<VectorType>(fromElementsOp.getResult().getType()).getNumElements();
+ for (unsigned level = 1; level < numLevels; ++level) {
+ const auto &prevLevelIntervals = inputIntervalsPerLevel[level - 1];
+ SmallVector<Interval> currentLevelIntervals(
+ llvm::divideCeil(prevLevelIntervals.size(), 2),
+ {kMaxUnsigned, kMaxUnsigned});
+
+ for (size_t inputIdx = 0; inputIdx < currentLevelIntervals.size();
+ ++inputIdx) {
+ auto &interval = currentLevelIntervals[inputIdx];
+ const auto &prevLhsInterval = prevLevelIntervals[inputIdx * 2];
+ const auto &prevRhsInterval = prevLevelIntervals[inputIdx * 2 + 1];
+
+ // The interval of a vector at the current level is the union of the
+ // intervals of the two input vectors from the previous level being
+ // shuffled at this level.
+ interval.first = std::min(prevLhsInterval.first, prevRhsInterval.first);
+ interval.second =
+ std::min(std::max(prevLhsInterval.second, prevRhsInterval.second),
+ outputNumElements - 1);
+ }
+
+ inputIntervalsPerLevel.push_back(std::move(currentLevelIntervals));
+ }
+}
+
+/// Compute the uniform output vector size for each level of the shuffle tree,
+/// given the intervals of the input vectors at that level. The output vector
+/// size of a level is the size of the widest interval resulting from shuffling
+/// each pair of input vectors.
+///
+/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+/// Intervals:
+/// * Level 0: [0,6], [1,7], [2,8], [2,8]
+/// * Level 1: [0,7], [2,8]
+///
+/// Vector sizes:
+/// * Level 0: max(size_of([0,6] U [1,7] = [0,7]) = 8,
+/// size_of([2,8] U [2,8] = [2,8]) = 7) = 8
+///
+/// * Level 1: max(size_of([0,7] U [2,8] = [0,8]) = 9) = 9
+///
+void VectorShuffleTreeBuilder::computeOutputVectorSizePerLevel() {
+ // Compute vector size for each level.
+ for (unsigned level = 0; level < numLevels; ++level) {
+ const auto ¤tLevelIntervals = inputIntervalsPerLevel[level];
+ unsigned currentVectorSize = 1;
+ for (size_t i = 0; i < currentLevelIntervals.size(); i += 2) {
+ const auto &lhsInterval = currentLevelIntervals[i];
+ const auto &rhsInterval = currentLevelIntervals[i + 1];
+ unsigned combinedIntervalSize =
+ std::max(lhsInterval.second, rhsInterval.second) - lhsInterval.first +
+ 1;
+ currentVectorSize = std::max(currentVectorSize, combinedIntervalSize);
+ }
+ vectorSizePerLevel[level] = currentVectorSize;
+ }
+}
+
+void VectorShuffleTreeBuilder::dump() {
+ LLVM_DEBUG({
+ unsigned indLv = 0;
+
+ llvm::dbgs() << "VectorShuffleTreeBuilder Configuration:\n";
+ ++indLv;
+ llvm::dbgs() << llvm::indent(indLv, kIndScale) << "* Inputs:\n";
+ ++indLv;
+ for (const auto &toElementsOp : toElementsDefs)
+ llvm::dbgs() << llvm::indent(indLv, kIndScale) << toElementsOp << "\n";
+ llvm::dbgs() << llvm::indent(indLv, kIndScale) << fromElementsOp << "\n\n";
+ --indLv;
+
+ llvm::dbgs() << llvm::indent(indLv, kIndScale)
+ << "* Total levels: " << numLevels << "\n";
+ llvm::dbgs() << llvm::indent(indLv, kIndScale)
+ << "* Vector sizes per level: [";
+ llvm::interleaveComma(vectorSizePerLevel, llvm::dbgs());
+ llvm::dbgs() << "]\n";
+ llvm::dbgs() << llvm::indent(indLv, kIndScale)
+ << "* Input intervals per level:\n";
+ ++indLv;
+ for (const auto &[level, intervals] :
+ llvm::enumerate(inputIntervalsPerLevel)) {
+ llvm::dbgs() << llvm::indent(indLv, kIndScale) << "* Level " << level
+ << ": ";
+ llvm::interleaveComma(intervals, llvm::dbgs(),
+ [](const Interval &interval) {
+ llvm::dbgs() << "[" << interval.first << ","
+ << interval.second << "]";
+ });
+ llvm::dbgs() << "\n";
+ }
+ });
+}
+
+/// Compute the shuffle tree configuration for the given `vector.to_elements` +
+/// `vector.from_elements` input sequence. This method builds a balanced binary
+/// shuffle tree that combines pairs of input vectors at each level.
+///
+/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+/// %0:5 = vector.to_elements %a : vector<5xf32>
+/// %1:5 = vector.to_elements %b : vector<5xf32>
+/// %2:5 = vector.to_elements %c : vector<5xf32>
+/// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
+/// %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+///
+/// build a tree that looks like:
+///
+/// %2 %1 %0 %0
+/// \ / \ /
+/// %2_1 = vector.shuffle %0_0 = vector.shuffle
+/// \ /
+/// %2_1_0_0 =vector.shuffle
+///
+/// The configuration comprises of computing the intervals of the input vectors
+/// at each level of the shuffle tree (i.e., %2, %1, %0, %0, %2_1, %0_0 and
+/// %2_1_0_0) and the output vector size for each level. For further details on
+/// intervals and output vector size computation, please, take a look at the
+/// corresponding utility functions.
+LogicalResult VectorShuffleTreeBuilder::computeShuffleTree() {
+ // Initialize shuffle tree information based on its size.
+ assert(toElementsDefs.size() > 1 &&
+ "At least two 'vector.to_elements' ops are required");
+ numLevels = llvm::Log2_64(toElementsDefs.size());
+ vectorSizePerLevel.resize(numLevels, 0);
+ inputIntervalsPerLevel.reserve(numLevels);
+
+ computeInputVectorIntervals();
+ computeOutputVectorSizePerLevel();
+ dump();
+
+ return success();
+}
+
+// ===--------------------------------------------------------------------===//
+// Shuffle Tree Code Generation Utilities.
+// ===--------------------------------------------------------------------===//
+
+/// Compute the permutation mask for shuffling two input `vector.to_elements`
+/// ops. The permutation mask is the mapping of the input vector elements to
+/// their final position in the output vector, relative to the intermediate
+/// output vector of the `vector.shuffle` operation combining the two inputs.
+///
+/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+/// %0:5 = vector.to_elements %a : vector<5xf32>
+/// %1:5 = vector.to_elements %b : vector<5xf32>
+/// %2:5 = vector.to_elements %c : vector<5xf32>
+/// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
+/// %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+///
+/// =>
+///
+/// // Level 0, vector length = 8
+/// %2_1 = PermutationShuffleMask(%2, %1) = [2,...
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the nice documention! I think I get the basic idea, but I need to spend some more time getting into the details. Possible edge case to test out:
func.func @foo(
%a : vector<2xf32>,
%b : vector<1xf32>,
%c : vector<f32>,
%d : vector<f32>,
%e : vector<f32>) -> vector<6xf32> {
%0:2 = vector.to_elements %a : vector<2xf32>
%1:1 = vector.to_elements %b : vector<1xf32>
%2:1 = vector.to_elements %c : vector<f32>
%3:1 = vector.to_elements %d : vector<f32>
%4:1 = vector.to_elements %e : vector<f32>
%5 = vector.from_elements %0#0, %1#0, %2#0, %3#0, %4#0, %0#1 : vector<6xf32>
return %5 : vector<6xf32>
}
LogicalResult matchAndRewrite(vector::FromElementsOp fromElementsOp, | ||
PatternRewriter &rewriter) const override { | ||
VectorType resultType = fromElementsOp.getType(); | ||
if (resultType.getRank() != 1 || resultType.isScalable()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not related to this PR, but this rank check got me wondering. I would like to propose removing the implicit abillity to do a shape_cast out of vector.to_elements
and vector.from_elements
operations, so that they must act on rank-1 vectors. Actually I've thought this before for other Vector ops that do reshape-like things.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The check here comes from the limitations of vector.shape
to represent n-D shuffles, not really from the vector.to_/from_elements
. That limitation is actually more like a TODO that we should address at some point.
vector.to_/from_elements
semantics naturally extend to n-D vectors given the extraction/insertion order they define but, yes, I guess we could see it as an "implicit shape cast"...
I think, though, we've been moving towards the opposite direction. To have a cohesive multi-dimensional vector layer we need these "implicit shape casts" so that ops work nicely across the board without having to special-case 1-D from n-D... This supports even more the idea that shape casts are really no-ops...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since the check for scalable vectors is on the same line ... :)
FromElementsOp
doesn't support scalable vectors. It would be good to add a comment - or better yet, replace failure with notifyMatchFailure
. 🙂
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess the solution to this particular TODO is to have the final shuffle create a 1-D vector that feeds a shape_cast op that is the final replacement of the from_elements.
I think that ops could still work nicely across the board with more explicit shape_casts, without detracting from their n-D nature. But if the only operation that allows the rank to change is a shape_cast, many of the (quadratic) interactions between ops would be greatly simplified. A topic for another place and time, let me focus on this PR now!
|
||
// Duplicate the last vector if the number of `vector.to_elements` is odd to | ||
// simplify the shuffle tree algorithm. | ||
if (toElementsDefs.size() % 2 != 0) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this be a check that it is a power of 2?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This check refers to the number of vector.to_elements
inputs to combine so we want to be able to combine an arbitrary number of inputs. If that number is not even, we duplicate the las input to simplify the algorithm (the shuffle for that input would have the same input vector twice). Does it make sense?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It makes sense. I was thinking that you might want to 'pad' all the way to a power of 2 so that shuffles at all depths were good. But your approach of a padding to a power of 2 at each level is more efficiently (O(logN) vs O(N) padding).
++inputIdx) { | ||
auto &interval = currentLevelIntervals[inputIdx]; | ||
const auto &prevLhsInterval = prevLevelIntervals[inputIdx * 2]; | ||
const auto &prevRhsInterval = prevLevelIntervals[inputIdx * 2 + 1]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Related to power-of-2 comment: If previous level here had 3 intervals, current level has 2. If inputIdx = 1 here, you're accessing index 3 of previous intervals -- problem? That's why I think it might be necessary to ensure the number of starting intervals is a power of 2 (stricter than just being even).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch! I thought I had a check to duplicate the last input, similar to the one in the constructor, but I must have removed it at some point. Let me fix that.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot! Happy to clarify any questions you may have!
|
||
// Duplicate the last vector if the number of `vector.to_elements` is odd to | ||
// simplify the shuffle tree algorithm. | ||
if (toElementsDefs.size() % 2 != 0) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This check refers to the number of vector.to_elements
inputs to combine so we want to be able to combine an arbitrary number of inputs. If that number is not even, we duplicate the las input to simplify the algorithm (the shuffle for that input would have the same input vector twice). Does it make sense?
++inputIdx) { | ||
auto &interval = currentLevelIntervals[inputIdx]; | ||
const auto &prevLhsInterval = prevLevelIntervals[inputIdx * 2]; | ||
const auto &prevRhsInterval = prevLevelIntervals[inputIdx * 2 + 1]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch! I thought I had a check to duplicate the last input, similar to the one in the constructor, but I must have removed it at some point. Let me fix that.
LogicalResult matchAndRewrite(vector::FromElementsOp fromElementsOp, | ||
PatternRewriter &rewriter) const override { | ||
VectorType resultType = fromElementsOp.getType(); | ||
if (resultType.getRank() != 1 || resultType.isScalable()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The check here comes from the limitations of vector.shape
to represent n-D shuffles, not really from the vector.to_/from_elements
. That limitation is actually more like a TODO that we should address at some point.
vector.to_/from_elements
semantics naturally extend to n-D vectors given the extraction/insertion order they define but, yes, I guess we could see it as an "implicit shape cast"...
I think, though, we've been moving towards the opposite direction. To have a cohesive multi-dimensional vector layer we need these "implicit shape casts" so that ops work nicely across the board without having to special-case 1-D from n-D... This supports even more the idea that shape casts are really no-ops...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks - this is quite involved, but you've done a great job documenting and modularising it!
The high-level logic makes sense, but some of the finer details are still unclear to me. I’ll definitely need a few more passes through it 😅
As usual, I started with the tests to get a broad overview. I’ve left a few comments there - mostly suggesting more emphasis on edge cases. Maybe you could consider grouping the tests to make those clearer?
Also, replacing some (or most) uses of failure with notifyMatchFailure would be great 🙂 #selfDocumentingCode
FromElementsOp fromElemOp, ArrayRef<ToElementsOp> toElemDefs) | ||
: fromElementsOp(fromElemOp), toElementsDefs(toElemDefs) { | ||
|
||
assert(fromElementsOp && "from_elements op is required"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this first assert required? fromElemOp
is a mandatory argument.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it can be null?
/// TODO: Implement mask compression to reduce the number of intermediate poison | ||
/// values. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What do you mean by "mask compression"? I'm just curious.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[1, 1, -1, -1, -1, -1, 4, -1] -> [1, 1, 4, -1]
/// 2. Each input vector to each level is used only once. | ||
/// 3. The number of levels in the tree is: | ||
/// ceil(log2(# `vector.to_elements` ops)). | ||
/// 4. Vectors at each level of the tree have the same vector length. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What if the inputs to vector.to_elements
don't meet this criteria?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The current implementation bails out, although it should be easy to support...
// Avoid generating a shuffle tree for trivial `vector.to_elements` -> | ||
// `vector.from_elements` forwarding cases that do not require shuffling. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you add a test for this?
LogicalResult matchAndRewrite(vector::FromElementsOp fromElementsOp, | ||
PatternRewriter &rewriter) const override { | ||
VectorType resultType = fromElementsOp.getType(); | ||
if (resultType.getRank() != 1 || resultType.isScalable()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since the check for scalable vectors is on the same line ... :)
FromElementsOp
doesn't support scalable vectors. It would be good to add a comment - or better yet, replace failure with notifyMatchFailure
. 🙂
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nit] Drop to_from
(and all the variants of that) from function names. The test file already encodes the fact that all tests exercise the vector.to_elements
+ vector.from_elements
-> vector.shuffle
.
Also, what are the high-level categories in this test files? I see two:
- genuine shuffle (e.g.
@to_from_elements_single_input_shuffle
) - concat
@to_from_elements_shuffle_tree_concat_4x8_to_32
- concat with poison values (e.g.
@to_from_elements_shuffle_tree_concat_3x4_to_12
)
Anything else? If this is correct, it would be good to clarify this split.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's complicated but it general we have 3 categories: concatenations, broadcast and arbitrary shuffles. I'm using those tags in the function names. Poison vs non-poison is a bit orthogonal as poison may appear at any level of the tree (or not...)
// where L# refers to the level of the tree the shuffle belongs to, and SH# refers to | ||
// the shuffle index within that level. | ||
|
||
func.func @to_from_elements_single_input_shuffle(%a: vector<8xf32>) -> vector<8xf32> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Comparing these function names its hard to tell what the difference is:
@to_from_elements_single_input_shuffle
,@from_elements_to_elements_single_shuffle
Wouldn't this be clearer:
@single_input
@multiple_inputs
or@two_inputs
Ultimately, it's:
// single input
%1 = vector.from_elements %0#7, %0#0, %0#6, %0#1, %0#5, %0#2, %0#4, %0#3 : vector<8xf32>
vs
// two inputs
%2 = vector.from_elements %0#7, %1#0, %0#6, %1#1, %0#5, %1#2, %0#4, %1#3 : vector<8xf32>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The logic that I've followed is:
- Shuffles have two inputs so no need to specify "multi_input" shuffle every time. It's the "default".
- Single input shuffle is the exception so it's worth adding the "single_input" tag for it.
- Shuffle tree has multiple shuffles in general so no need to specify "multi_suffle". It's the "default"
- Single shuffle tree is the exception so it's worth adding the "single_shuffle" tag for it.
|
||
// ----- | ||
|
||
func.func @to_from_elements_shuffle_tree_concat_64x4_256( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IMHO, this example is a bit too long and I'm not sure whether it adds much unique coverage. Do we believe that jumping from e.g. 4 to 64 input vectors changes much?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, it changes the depth of the tree... the other tests are mostly generating 1 or 2 levels. I think it's important to test a large depth at least once.
/// * inputIntervalsPerLevel[0][0] = interval(%2) = [0,6] | ||
/// * inputIntervalsPerLevel[0][1] = interval(%1) = [1,7] | ||
/// * inputIntervalsPerLevel[0][2] = interval(%0) = [2,8] | ||
/// * inputIntervalsPerLevel[0][3] = interval(%0) = [2,8] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is confusing to me. Looking at this sentence above:
/// The interval of an input vector is the range of positions in the final output that the input vector contributes to.
And:
/// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
/// %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
I see that the range for %0 is [0, 4]
(%0#
+ %0#4
), but:
/// * inputIntervalsPerLevel[0][0] = interval(%2) = [0,6]
Could you add a bit more explanation?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure I follow the comment. Let me explain to see where the gap is:
Level 0 has 4 inputs (%2, %1, %0, %0, ...
- inputIntervalsPerLevel[0][0]
The first index corresponds to the level (0) and the second to the input at that level, so the input 0 at level 0 is %2
. That's why:
/// * inputIntervalsPerLevel[0][0] = interval(%2)
Could you help me understand what is missing?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we need this transformation in the first place? Is it to make lowering to llvm/spirv easier?
This is mostly implementing "2. Simplified Pattern Recognition and Optimization" in https://discourse.llvm.org/t/rfc-adding-vector-to-elements-op-to-the-vector-dialect/86779. It's not about making the lowering easier but far more efficient both in terms of performance and compile time. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the feedback!
- Fixed value/op/interval duplication bug.
- Constrained the match to uniform vector inputs with rank.
- Added more tests
- Addressed misc. feedback
/// 2. Each input vector to each level is used only once. | ||
/// 3. The number of levels in the tree is: | ||
/// ceil(log2(# `vector.to_elements` ops)). | ||
/// 4. Vectors at each level of the tree have the same vector length. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The current implementation bails out, although it should be easy to support...
/// TODO: Implement mask compression to reduce the number of intermediate poison | ||
/// values. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[1, 1, -1, -1, -1, -1, 4, -1] -> [1, 1, 4, -1]
/// * inputIntervalsPerLevel[0][0] = interval(%2) = [0,6] | ||
/// * inputIntervalsPerLevel[0][1] = interval(%1) = [1,7] | ||
/// * inputIntervalsPerLevel[0][2] = interval(%0) = [2,8] | ||
/// * inputIntervalsPerLevel[0][3] = interval(%0) = [2,8] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure I follow the comment. Let me explain to see where the gap is:
Level 0 has 4 inputs (%2, %1, %0, %0, ...
- inputIntervalsPerLevel[0][0]
The first index corresponds to the level (0) and the second to the input at that level, so the input 0 at level 0 is %2
. That's why:
/// * inputIntervalsPerLevel[0][0] = interval(%2)
Could you help me understand what is missing?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's complicated but it general we have 3 categories: concatenations, broadcast and arbitrary shuffles. I'm using those tags in the function names. Poison vs non-poison is a bit orthogonal as poison may appear at any level of the tree (or not...)
// where L# refers to the level of the tree the shuffle belongs to, and SH# refers to | ||
// the shuffle index within that level. | ||
|
||
func.func @to_from_elements_single_input_shuffle(%a: vector<8xf32>) -> vector<8xf32> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The logic that I've followed is:
- Shuffles have two inputs so no need to specify "multi_input" shuffle every time. It's the "default".
- Single input shuffle is the exception so it's worth adding the "single_input" tag for it.
- Shuffle tree has multiple shuffles in general so no need to specify "multi_suffle". It's the "default"
- Single shuffle tree is the exception so it's worth adding the "single_shuffle" tag for it.
|
||
// ----- | ||
|
||
func.func @to_from_elements_shuffle_tree_concat_64x4_256( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, it changes the depth of the tree... the other tests are mostly generating 1 or 2 levels. I think it's important to test a large depth at least once.
This PR adds a new transformation that turns sequences of
vector.to_elements
andvector.from_elements
into a binary tree ofvector.shuffle
operations.(Related RFC: https://discourse.llvm.org/t/rfc-adding-vector-to-elements-op-to-the-vector-dialect/86779).
Example:
The algorithm leverages the structured extraction/insertion information of
vector.to_elements
andvector.from_elements
operations and builds a set of intervals to determine the vector length that should be used at each level of the tree to combine the level inputs in pairs.There are a few improvements that can be implemented in the future, such as shuffle mask compression to avoid unnecessarily large vector lengths with poison values, but I decided to keep things "simpler" and spend more time documenting the different steps of the algorithm so that people can follow along.